# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains code to create and manage SageMaker ``LineageTrialComponent``."""
from __future__ import absolute_import

import logging

from typing import List, Optional

from sagemaker.core.apiutils import _base_types
from sagemaker.core.lineage.query import (
    LineageQuery,
    LineageFilter,
    LineageSourceEnum,
    LineageEntityEnum,
    LineageQueryDirectionEnum,
)
from sagemaker.core.lineage.artifact import Artifact


# Module-level logger registered under the "sagemaker" logger name.
LOGGER = logging.getLogger("sagemaker")


class LineageTrialComponent(_base_types.Record):
    """An Amazon SageMaker lineage trial component, which is part of a SageMaker lineage.

    A trial component is a stage in a trial.
    Trial components are created automatically within the SageMaker runtime and can also be
    created directly. To automatically associate trial components with a trial and experiment,
    supply an experiment config when creating a job.
    For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html

    Attributes:
        trial_component_name (str): The name of the trial component. Generated by SageMaker from the
            name of the source job with a suffix specific to the type of source job.
        trial_component_arn (str): The ARN of the trial component.
        display_name (str): The name of the trial component that will appear in UI,
            such as SageMaker Studio.
        source (dict): The source of the trial component. ``pipeline_execution_arn`` reads the
            source job ARN from its ``"SourceArn"`` entry.
        status (str): Status of the source job.
        start_time (datetime): When the source job started.
        end_time (datetime): When the source job ended.
        creation_time (datetime): When the source job was created.
        created_by (obj): Contextual info on which account created the trial component.
        last_modified_time (datetime): When the trial component was last modified.
        last_modified_by (obj): Contextual info on which account last modified the trial component.
        parameters (dict): Dictionary of parameters to the source job.
        input_artifacts (dict): Dictionary of input artifacts.
        output_artifacts (dict): Dictionary of output artifacts.
        metrics (obj): Aggregated metrics for the job.
        parameters_to_remove (list): The hyperparameters to remove from the component.
        input_artifacts_to_remove (list): The input artifacts to remove from the component.
        output_artifacts_to_remove (list): The output artifacts to remove from the component.
        tags (List[dict[str, str]]): A list of tags to associate with the trial component.
    """

    trial_component_name = None
    trial_component_arn = None
    display_name = None
    source = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
    parameters = None
    input_artifacts = None
    output_artifacts = None
    metrics = None
    parameters_to_remove = None
    input_artifacts_to_remove = None
    output_artifacts_to_remove = None
    tags = None

    _boto_create_method: str = "create_trial_component"
    _boto_load_method: str = "describe_trial_component"
    _boto_update_method: str = "update_trial_component"
    _boto_delete_method: str = "delete_trial_component"

    _boto_update_members = [
        "trial_component_name",
        "display_name",
        "status",
        "start_time",
        "end_time",
        "parameters",
        "input_artifacts",
        "output_artifacts",
        "parameters_to_remove",
        "input_artifacts_to_remove",
        "output_artifacts_to_remove",
    ]
    _boto_delete_members = ["trial_component_name"]

    @classmethod
    def load(cls, trial_component_name: str, sagemaker_session=None) -> "LineageTrialComponent":
        """Load an existing trial component and return a ``LineageTrialComponent`` representing it.

        Args:
            trial_component_name (str): Name of the trial component.
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            LineageTrialComponent: A SageMaker ``LineageTrialComponent`` object.
        """
        trial_component = cls._construct(
            cls._boto_load_method,
            trial_component_name=trial_component_name,
            sagemaker_session=sagemaker_session,
        )
        return trial_component

    def pipeline_execution_arn(self) -> Optional[str]:
        """Get the ARN for the pipeline execution associated with this trial component (if any).

        The trial component is described again, and the tags of its source resource are searched
        (across all result pages) for the ``sagemaker:pipeline-execution-arn`` tag.

        Returns:
            Optional[str]: The pipeline execution ARN, or ``None`` if the trial component has no
            source ARN or its source carries no pipeline execution tag.
        """
        # Describe the trial component again rather than relying on ``self.source``,
        # which may not be populated on this instance.
        trial_component = self.load(
            trial_component_name=self.trial_component_name, sagemaker_session=self.sagemaker_session
        )

        # ``source`` is a plain dict; use .get so a missing key yields None instead of KeyError.
        source_arn = (trial_component.source or {}).get("SourceArn")
        if source_arn is None:
            return None

        client = self.sagemaker_session.sagemaker_client
        list_tags_kwargs = {"ResourceArn": source_arn}
        # ListTags is paginated; follow NextToken so the tag is not missed on a later page.
        while True:
            response = client.list_tags(**list_tags_kwargs)
            for tag in response.get("Tags", []):
                if tag["Key"] == "sagemaker:pipeline-execution-arn":
                    return tag["Value"]
            next_token = response.get("NextToken")
            if not next_token:
                return None
            list_tags_kwargs["NextToken"] = next_token

    def _query_artifacts(
        self, source: LineageSourceEnum, direction: LineageQueryDirectionEnum
    ) -> List[Artifact]:
        """Run a lineage query from this trial component for artifacts of the given source type.

        Args:
            source (LineageSourceEnum): The artifact source type to filter on.
            direction (LineageQueryDirectionEnum): The query direction.

        Returns:
            list of Artifacts: The matching vertices converted to lineage objects.
        """
        query_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT], sources=[source])
        query_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.trial_component_arn],
            query_filter=query_filter,
            direction=direction,
            include_edges=False,
        )
        return [vertex.to_lineage_object() for vertex in query_result.vertices]

    def dataset_artifacts(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[Artifact]:
        """Use the lineage query to retrieve dataset artifacts associated with this trial component.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.ASCENDANTS``.

        Returns:
            list of Artifacts: Artifacts representing a dataset.
        """
        return self._query_artifacts(LineageSourceEnum.DATASET, direction)

    def models(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List[Artifact]:
        """Use the lineage query to retrieve model artifacts associated with this trial component.

        Args:
            direction (LineageQueryDirectionEnum, optional): The query direction.
                Defaults to ``LineageQueryDirectionEnum.DESCENDANTS``.

        Returns:
            list of Artifacts: Artifacts representing a model.
        """
        return self._query_artifacts(LineageSourceEnum.MODEL, direction)
